{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `FoldConstant`\n",
    "\n",
    "参考：`tvm/tests/python/relax/test_transform_fold_constant.py`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "import tvm.testing\n",
    "from tvm import relax\n",
    "import numpy as np\n",
    "\n",
    "import tvm.script\n",
    "from tvm.script import ir as I, tir as T, relax as R"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_mod(mod, name, binding):\n",
    "    \"\"\"Build a single-entry module from ``mod`` and bind constants into it.\n",
    "\n",
    "    The Relax function called ``name`` is kept and re-registered as ``main``;\n",
    "    every other Relax function is dropped, while non-Relax functions (e.g. TIR\n",
    "    prim funcs) are carried over unchanged. The values in ``binding`` are then\n",
    "    bound to ``main`` through ``relax.transform.BindParams``.\n",
    "\n",
    "    Args:\n",
    "        mod: Source IRModule, possibly containing several functions.\n",
    "        name: Name of the Relax function to keep and rename to ``main``.\n",
    "        binding: Mapping from parameter name to the NumPy array to bind.\n",
    "\n",
    "    Returns:\n",
    "        The module returned by ``relax.transform.BindParams(\"main\", ...)``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``mod`` has no Relax function named ``name``.\n",
    "    \"\"\"\n",
    "    # Convert the NumPy inputs to TVM NDArrays before handing them to BindParams.\n",
    "    binding = {k: tvm.nd.array(v) for k, v in binding.items()}\n",
    "\n",
    "    funcs = {}\n",
    "    found = False\n",
    "    for gvar, func in mod.functions.items():\n",
    "        if not isinstance(func, tvm.relax.Function):\n",
    "            # Keep every non-Relax function (e.g. TIR prim funcs) as is.\n",
    "            funcs[gvar] = func\n",
    "        elif gvar.name_hint == name:\n",
    "            # Rebuild the selected Relax function under the \"main\" symbol.\n",
    "            main_gv = tvm.ir.GlobalVar(\"main\")\n",
    "            funcs[main_gv] = tvm.relax.Function(\n",
    "                func.params, func.body, func.ret_struct_info\n",
    "            ).with_attr(\"global_symbol\", \"main\")\n",
    "            found = True\n",
    "\n",
    "    if not found:\n",
    "        # Without this check the rebuilt module would simply lack a \"main\" function.\n",
    "        raise ValueError(f\"No Relax function named {name!r} found in the module\")\n",
    "\n",
    "    return relax.transform.BindParams(\"main\", binding)(tvm.IRModule(funcs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 折叠 `+1` 常量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">addone</span>(A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
       "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
       "        <span style=\"color: #008000; font-weight: bold\">for</span> i, j <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>):\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;addone&quot;</span>):\n",
       "                vi, vj <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SS&quot;</span>, [i, j])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(A[vi, vj])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(B[vi, vj])\n",
       "                B[vi, vj] <span style=\"color: #AA22FF; font-weight: bold\">=</span> A[vi, vj] <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">1.0</span>)\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>() <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        cls <span style=\"color: #AA22FF; font-weight: bold\">=</span> Module\n",
       "        lv0 <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>call_tir(cls<span style=\"color: #AA22FF; font-weight: bold\">.</span>addone, (metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">0</span>],), out_sinfo<span style=\"color: #AA22FF; font-weight: bold\">=</span>R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>))\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> lv0\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "@I.ir_module\n",
    "class Module:\n",
    "    # TIR kernel: element-wise B = A + 1 over a 16x16 float32 buffer.\n",
    "    @T.prim_func\n",
    "    def addone(A: T.Buffer((16, 16), \"float32\"), B: T.Buffer((16, 16), \"float32\")) -> None:\n",
    "        for i, j in T.grid(16, 16):\n",
    "            with T.block(\"addone\"):\n",
    "                vi, vj = T.axis.remap(\"SS\", [i, j])\n",
    "                B[vi, vj] = A[vi, vj] + T.float32(1)\n",
    "\n",
    "    # Relax function that calls the kernel once on its only parameter.\n",
    "    @R.function\n",
    "    def before(c0: R.Tensor((16, 16), \"float32\")):\n",
    "        cls = Module\n",
    "        lv0 = R.call_tir(cls.addone, (c0,), out_sinfo=R.Tensor((16, 16), dtype=\"float32\"))\n",
    "        return lv0\n",
    "\n",
    "# c0_np is bound to the parameter c0; c1_np is the NumPy reference for c0 + 1.\n",
    "c0_np = np.arange(16 * 16, dtype=\"float32\").reshape(16, 16)\n",
    "c1_np = c0_np + 1\n",
    "before = gen_mod(Module, \"before\", {\"c0\": c0_np})\n",
    "before.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">addone</span>(A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
       "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
       "        <span style=\"color: #008000; font-weight: bold\">for</span> i, j <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>):\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;addone&quot;</span>):\n",
       "                vi, vj <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SS&quot;</span>, [i, j])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(A[vi, vj])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(B[vi, vj])\n",
       "                B[vi, vj] <span style=\"color: #AA22FF; font-weight: bold\">=</span> A[vi, vj] <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">1.0</span>)\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>() <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">0</span>]\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "after = relax.transform.FoldConstant()(before)\n",
    "after.show()\n",
    "\n",
    "# The printed IR omits constant values, so compare the folded constant\n",
    "# with the NumPy reference c1_np explicitly.\n",
    "folded_const = after[\"main\"].body.body\n",
    "assert isinstance(folded_const, relax.Constant)\n",
    "tvm.testing.assert_allclose(folded_const.data.numpy(), c1_np)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 折叠常量的转置"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "@I.ir_module\n",
    "class Module:\n",
    "    # TIR kernel: writes the transpose of the (2, 3) buffer A into the (3, 2) buffer B.\n",
    "    @T.prim_func\n",
    "    def func(A: T.Buffer((2, 3), \"float32\"), B: T.Buffer((3, 2), \"float32\")) -> None:\n",
    "        for i, j in T.grid(3, 2):\n",
    "            with T.block(\"transpose\"):\n",
    "                vi, vj = T.axis.remap(\"SS\", [i, j])\n",
    "                B[vi, vj] = A[vj, vi]\n",
    "\n",
    "    # Relax function that calls the kernel once on its only parameter.\n",
    "    @R.function\n",
    "    def before(c0: R.Tensor((2, 3), \"float32\")):\n",
    "        cls = Module\n",
    "        lv0 = R.call_tir(cls.func, (c0,), out_sinfo=R.Tensor((3, 2), dtype=\"float32\"))\n",
    "        return lv0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">func</span>(A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">3</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">2</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
       "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
       "        <span style=\"color: #008000; font-weight: bold\">for</span> i, j <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">2</span>):\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;transpose&quot;</span>):\n",
       "                vi, vj <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SS&quot;</span>, [i, j])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(A[vj, vi])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(B[vi, vj])\n",
       "                B[vi, vj] <span style=\"color: #AA22FF; font-weight: bold\">=</span> A[vj, vi]\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>() <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        cls <span style=\"color: #AA22FF; font-weight: bold\">=</span> Module\n",
       "        lv0 <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>call_tir(cls<span style=\"color: #AA22FF; font-weight: bold\">.</span>func, (metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">0</span>],), out_sinfo<span style=\"color: #AA22FF; font-weight: bold\">=</span>R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>))\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> lv0\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">func</span>(A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">3</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">2</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
       "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
       "        <span style=\"color: #008000; font-weight: bold\">for</span> i, j <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">2</span>):\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;transpose&quot;</span>):\n",
       "                vi, vj <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SS&quot;</span>, [i, j])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(A[vj, vi])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(B[vi, vj])\n",
       "                B[vi, vj] <span style=\"color: #AA22FF; font-weight: bold\">=</span> A[vj, vi]\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>() <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">0</span>]\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# c0_np is bound to the parameter c0; c1_np is the NumPy reference transpose.\n",
    "c0_np = np.arange(2 * 3).astype(\"float32\").reshape(2, 3)\n",
    "c1_np = c0_np.T\n",
    "before = gen_mod(Module, \"before\", {\"c0\": c0_np})\n",
    "before.show()\n",
    "after = relax.transform.FoldConstant()(before)\n",
    "after.show()\n",
    "\n",
    "# The printed IR omits constant values, so compare the folded constant\n",
    "# with the NumPy reference c1_np explicitly.\n",
    "folded_const = after[\"main\"].body.body\n",
    "assert isinstance(folded_const, relax.Constant)\n",
    "tvm.testing.assert_allclose(folded_const.data.numpy(), c1_np)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 折叠连续两次 `addone`（two_hop_addone）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "@I.ir_module\n",
    "class Module:\n",
    "    # TIR kernel: element-wise B = A + 1 over a 2x2 float32 buffer.\n",
    "    @T.prim_func\n",
    "    def addone(A: T.Buffer((2, 2), \"float32\"), B: T.Buffer((2, 2), \"float32\")) -> None:\n",
    "        for i, j in T.grid(2, 2):\n",
    "            with T.block(\"addone\"):\n",
    "                vi, vj = T.axis.remap(\"SS\", [i, j])\n",
    "                B[vi, vj] = A[vi, vj] + T.float32(1)\n",
    "\n",
    "    # Relax function that chains two kernel calls: the second consumes the first's result.\n",
    "    @R.function\n",
    "    def before(c0: R.Tensor((2, 2), \"float32\")):\n",
    "        cls = Module\n",
    "        lv0 = R.call_tir(cls.addone, (c0,), out_sinfo=R.Tensor((2, 2), dtype=\"float32\"))\n",
    "        lv1 = R.call_tir(cls.addone, (lv0,), out_sinfo=R.Tensor((2, 2), dtype=\"float32\"))\n",
    "        return lv1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">addone</span>(A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
       "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
       "        <span style=\"color: #008000; font-weight: bold\">for</span> i, j <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>):\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;addone&quot;</span>):\n",
       "                vi, vj <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SS&quot;</span>, [i, j])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(A[vi, vj])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(B[vi, vj])\n",
       "                B[vi, vj] <span style=\"color: #AA22FF; font-weight: bold\">=</span> A[vi, vj] <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">1.0</span>)\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>() <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        cls <span style=\"color: #AA22FF; font-weight: bold\">=</span> Module\n",
       "        lv0 <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>call_tir(cls<span style=\"color: #AA22FF; font-weight: bold\">.</span>addone, (metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">0</span>],), out_sinfo<span style=\"color: #AA22FF; font-weight: bold\">=</span>R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>))\n",
       "        lv1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>call_tir(cls<span style=\"color: #AA22FF; font-weight: bold\">.</span>addone, (lv0,), out_sinfo<span style=\"color: #AA22FF; font-weight: bold\">=</span>R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>))\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> lv1\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">addone</span>(A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
       "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
       "        <span style=\"color: #008000; font-weight: bold\">for</span> i, j <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>):\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;addone&quot;</span>):\n",
       "                vi, vj <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SS&quot;</span>, [i, j])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(A[vi, vj])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(B[vi, vj])\n",
       "                B[vi, vj] <span style=\"color: #AA22FF; font-weight: bold\">=</span> A[vi, vj] <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">1.0</span>)\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>() <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">0</span>]\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# c0_np is bound to the parameter c0; c1_np and c2_np are the NumPy references\n",
    "# after one and two applications of +1.\n",
    "c0_np = np.arange((2 * 2)).astype(\"float32\").reshape(2, 2)\n",
    "c1_np = c0_np + 1\n",
    "c2_np = c1_np + 1\n",
    "before = gen_mod(Module, \"before\", {\"c0\": c0_np})\n",
    "before.show()\n",
    "after = relax.transform.FoldConstant()(before)\n",
    "after.show()\n",
    "\n",
    "# The printed IR omits constant values, so compare the folded constant\n",
    "# with the two-step NumPy reference c2_np explicitly.\n",
    "folded_const = after[\"main\"].body.body\n",
    "assert isinstance(folded_const, relax.Constant)\n",
    "tvm.testing.assert_allclose(folded_const.data.numpy(), c2_np)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 折叠 dataflow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "@I.ir_module\n",
    "class Module:\n",
    "    # TIR kernel: copies the 16x16 float32 buffer A into B element by element.\n",
    "    @T.prim_func\n",
    "    def identity(A: T.Buffer((16, 16), \"float32\"), B: T.Buffer((16, 16), \"float32\")) -> None:\n",
    "        for i, j in T.grid(16, 16):\n",
    "            with T.block(\"identity\"):\n",
    "                vi, vj = T.axis.remap(\"SS\", [i, j])\n",
    "                B[vi, vj] = A[vi, vj]\n",
    "\n",
    "    # Relax function whose single kernel call sits inside a dataflow block.\n",
    "    @R.function\n",
    "    def before(c0: R.Tensor((16, 16), \"float32\")):\n",
    "        cls = Module\n",
    "        with R.dataflow():\n",
    "            gv0 = R.call_tir(cls.identity, (c0,), out_sinfo=R.Tensor((16, 16), dtype=\"float32\"))\n",
    "            R.output(gv0)\n",
    "        return gv0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">identity</span>(A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
       "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
       "        <span style=\"color: #008000; font-weight: bold\">for</span> i, j <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>):\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;identity&quot;</span>):\n",
       "                vi, vj <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SS&quot;</span>, [i, j])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(A[vi, vj])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(B[vi, vj])\n",
       "                B[vi, vj] <span style=\"color: #AA22FF; font-weight: bold\">=</span> A[vi, vj]\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>() <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        cls <span style=\"color: #AA22FF; font-weight: bold\">=</span> Module\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            gv0 <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>call_tir(cls<span style=\"color: #AA22FF; font-weight: bold\">.</span>identity, (metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">0</span>],), out_sinfo<span style=\"color: #AA22FF; font-weight: bold\">=</span>R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>))\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv0)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv0\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">identity</span>(A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
       "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
       "        <span style=\"color: #008000; font-weight: bold\">for</span> i, j <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>):\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;identity&quot;</span>):\n",
       "                vi, vj <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SS&quot;</span>, [i, j])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(A[vi, vj])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(B[vi, vj])\n",
       "                B[vi, vj] <span style=\"color: #AA22FF; font-weight: bold\">=</span> A[vi, vj]\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>() <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">0</span>]\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "c0_np = np.arange((16 * 16)).astype(\"float32\").reshape(16, 16)\n",
    "c1_np = c0_np\n",
    "before = gen_mod(Module, \"before\", {\"c0\": c0_np})\n",
    "before.show()\n",
    "after = relax.transform.FoldConstant()(before)\n",
    "after.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## fold_mixed_case"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "@I.ir_module\n",
    "class Module:\n",
    "    # Shape-polymorphic TIR kernel: B = A + 1 for an (n, m) buffer.\n",
    "    @T.prim_func\n",
    "    def addone(a: T.handle, b: T.handle):\n",
    "        n = T.int32()\n",
    "        m = T.int32()\n",
    "        A = T.match_buffer(a, (n, m))\n",
    "        B = T.match_buffer(b, (n, m))\n",
    "        for i, j in T.grid(n, m):\n",
    "            with T.block(\"addone\"):\n",
    "                vi, vj = T.axis.remap(\"SS\", [i, j])\n",
    "                B[vi, vj] = A[vi, vj] + T.float32(1)\n",
    "\n",
    "    # Fixed-shape TIR kernel: C = A - B on 16x16 float32 buffers.\n",
    "    @T.prim_func\n",
    "    def sub(\n",
    "        A: T.Buffer((16, 16), \"float32\"),\n",
    "        B: T.Buffer((16, 16), \"float32\"),\n",
    "        C: T.Buffer((16, 16), \"float32\"),\n",
    "    ):\n",
    "        for i, j in T.grid(16, 16):\n",
    "            with T.block(\"sub\"):\n",
    "                vi, vj = T.axis.remap(\"SS\", [i, j])\n",
    "                C[vi, vj] = A[vi, vj] - B[vi, vj]\n",
    "\n",
    "    # Mixes calls that can be folded with calls that cannot.\n",
    "    @R.function\n",
    "    def before(c0: R.Tensor((16, 16), \"float32\"), x: R.Tensor(\"float32\", ndim=2)):\n",
    "        n, m = T.int64(), T.int64()\n",
    "        cls = Module\n",
    "        x0 = R.match_cast(x, R.Tensor((n, m), \"float32\"))\n",
    "        # Not foldable: the output shape (n, 16) depends on the symbolic `n`.\n",
    "        lv0 = R.call_tir(cls.addone, (c0,), out_sinfo=R.Tensor((n, 16), dtype=\"float32\"))\n",
    "        # Foldable: constant input and a fully static output shape.\n",
    "        lv1 = R.call_tir(cls.addone, (c0,), out_sinfo=R.Tensor((16, 16), dtype=\"float32\"))\n",
    "        # Foldable: both inputs are constants once `lv1` has been folded.\n",
    "        lv2 = R.call_tir(cls.sub, (c0, lv1), out_sinfo=R.Tensor((16, 16), dtype=\"float32\"))\n",
    "        # Not foldable: `x` is a runtime input whose shape is unknown.\n",
    "        lv3 = R.call_tir(cls.sub, (lv2, x), out_sinfo=R.Tensor((16, 16), dtype=\"float32\"))\n",
    "        return (lv0, lv3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">identity</span>(A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
       "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
       "        <span style=\"color: #008000; font-weight: bold\">for</span> i, j <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>):\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;identity&quot;</span>):\n",
       "                vi, vj <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SS&quot;</span>, [i, j])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(A[vi, vj])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(B[vi, vj])\n",
       "                B[vi, vj] <span style=\"color: #AA22FF; font-weight: bold\">=</span> A[vi, vj]\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>() <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        cls <span style=\"color: #AA22FF; font-weight: bold\">=</span> Module\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            gv0 <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>call_tir(cls<span style=\"color: #AA22FF; font-weight: bold\">.</span>identity, (metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">0</span>],), out_sinfo<span style=\"color: #AA22FF; font-weight: bold\">=</span>R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>))\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv0)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv0\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">identity</span>(A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
       "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
       "        <span style=\"color: #008000; font-weight: bold\">for</span> i, j <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>):\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;identity&quot;</span>):\n",
       "                vi, vj <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SS&quot;</span>, [i, j])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(A[vi, vj])\n",
       "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(B[vi, vj])\n",
       "                B[vi, vj] <span style=\"color: #AA22FF; font-weight: bold\">=</span> A[vi, vj]\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>() <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">16</span>, <span style=\"color: #008000\">16</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">0</span>]\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "c0_np = np.arange((16 * 16)).astype(\"float32\").reshape(16, 16)\n",
    "c1_np = c0_np + 1\n",
    "c2_np = c0_np - c1_np\n",
    "\n",
    "before.show()\n",
    "after = relax.transform.FoldConstant()(before)\n",
    "after.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## fold_shape_computation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "@I.ir_module\n",
    "class Module:\n",
    "    # Selects entries of the static shape (5, 4, 3, 2) via `indices` and\n",
    "    # returns the selection as a (1, 1) int64 tensor.\n",
    "    @R.function\n",
    "    def before(\n",
    "        data: R.Tensor((5, 4, 3, 2), dtype=\"float32\"),\n",
    "        indices: R.Tensor((1,), dtype=\"int64\"),\n",
    "    ) -> R.Tensor((1, 1), dtype=\"int64\"):\n",
    "        with R.dataflow():\n",
    "            # Shape values [5, 4, 3, 2] as a (4,) int64 tensor.\n",
    "            lv = R.shape_to_tensor(R.shape([5, 4, 3, 2]))\n",
    "            # Gather along axis 0 with `indices` -> (1,) int64.\n",
    "            lv1 = R.take(lv, indices, axis=0)\n",
    "            # Insert a leading axis -> (1, 1) int64.\n",
    "            lv2 = R.expand_dims(lv1, axis=[0])\n",
    "            # Single-element concat along axis 0 -> (1, 1) int64.\n",
    "            gv = R.concat((lv2,), axis=0)\n",
    "            R.output(gv)\n",
    "        return gv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(data: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">5</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int64&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            lv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">4</span>,), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>shape_to_tensor(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>shape([<span style=\"color: #008000\">5</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">2</span>]))\n",
       "            lv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>,), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>take(lv, metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">0</span>], axis<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">0</span>)\n",
       "            lv2: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>expand_dims(lv1, axis<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">0</span>])\n",
       "            gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>concat((lv2,), axis<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">0</span>)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(data: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">5</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int64&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">0</span>]\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "before = gen_mod(Module, \"before\", {\"indices\": tvm.nd.array(np.array([0]).astype(\"int64\"))})\n",
    "before.show()\n",
    "after = relax.transform.FoldConstant()(before)\n",
    "after.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
